import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_iris
from sklearn.datasets import load_digits
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier
from csv import reader

# Shared figure: plot_path() draws one test-error curve per dataset onto `ax`.
fig, ax = plt.subplots()
ax.set(
    xlabel=r"$\tilde{\alpha}$",
    ylabel="average misclassification rate",
    title="Loss vs complexity parameter",
)

# Load CSV file
def load_csv(filename):
	file = open(filename, "rt")
	lines = reader(file)
	dataset = list(lines)
	return dataset

# Convert str column to float
def str_column_to_float(dataset, column):
	for row in dataset:
		row[column] = float(row[column].strip())

def plot_path(X, y, dataset, rs, axes=None, alpha_max=0.3):
    """Plot test misclassification rate against the pruning parameter alpha.

    Splits ``(X, y)`` into train/test sets and computes the minimal
    cost-complexity pruning path of a decision tree fitted on the training
    set. It then refits one tree per effective alpha and draws a step curve
    of the test-set misclassification rate (``1 - accuracy``).

    :param X: feature matrix (anything ``train_test_split`` accepts).
    :param y: labels.
    :param dataset: legend label for the curve.
    :param rs: ``random_state`` passed to ``train_test_split``.
    :param axes: matplotlib Axes to draw on; defaults to the module-level ``ax``.
    :param alpha_max: x position the final step is extended to, so the last
        error level stays visible to the right of the largest plotted alpha.
        If the largest alpha already exceeds it, that alpha is used instead.
    """
    if axes is None:
        axes = ax  # module-level axes shared by all curves
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rs)

    path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(
        X_train, y_train
    )
    ccp_alphas = path.ccp_alphas
    # Per the scikit-learn docs, the last alpha prunes the tree to a single
    # node. Drop it, unless it is the only alpha; dropping it then would leave
    # nothing to plot and crash below.
    if len(ccp_alphas) > 1:
        ccp_alphas = ccp_alphas[:-1]

    test_errors = []
    for ccp_alpha in ccp_alphas:
        clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
        clf.fit(X_train, y_train)
        test_errors.append(1 - clf.score(X_test, y_test))

    # Repeat the last error at a larger alpha so "steps-post" draws the final
    # horizontal segment; never step backwards along the x axis.
    x_end = max(alpha_max, ccp_alphas[-1])
    xs = np.append(ccp_alphas, x_end)
    test_errors.append(test_errors[-1])

    axes.plot(xs, test_errors, marker="o", label=dataset, drawstyle="steps-post")
    axes.legend()


# Built-in scikit-learn datasets, each paired with its legend label and the
# train/test split seed used for it (plotted in this order).
for loader, name, seed in (
    (load_breast_cancer, "breast_cancer", 0),
    (load_wine, "wine", 3),
    (load_iris, "iris", 2),
):
    X, y = loader(return_X_y=True)
    plot_path(X, y, name, seed)

# Banknote data: a headerless CSV (every field is parsed as a number) whose
# last column is the class label. Resolved relative to the working directory.
filename = "data_banknote_authentication.csv"
dataset = load_csv(filename)
# Convert every column (features and label) from string to float.
for column in range(len(dataset[0])):
    str_column_to_float(dataset, column)
# Split each row into features (all but the last column) and label (last).
X = [row[:-1] for row in dataset]
y = [row[-1] for row in dataset]
plot_path(X, y, "banknote", 4)

# Restrict the x axis to small alphas.
plt.xlim([0, 0.1])
plt.show()